# -*- coding: utf-8 -*-
"""
@author: dengpanxiao(@126.com)

@file: linear_regression.py

@time: 17/6/5 下午7:11

@desc: 线性回归 (linear regression): fit pizza price against diameter and predict the price of a 12-inch pizza

"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression


def runplt(title=u'diameter-cost curve', xlabel=u'diameter', ylabel=u'cost',
           axis=(0, 25, 0, 25)):
    """Open a new figure preconfigured for the diameter/cost scatter plot.

    All arguments are optional; the defaults reproduce the original layout.

    Args:
        title: Figure title.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        axis: ``(xmin, xmax, ymin, ymax)`` limits passed to ``plt.axis``.

    Returns:
        The ``matplotlib.pyplot`` module itself, so callers can keep drawing
        on the freshly created current figure.
    """
    plt.figure()
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # plt.axis takes a 4-element sequence; convert the immutable default to a
    # list to match the original call exactly.
    plt.axis(list(axis))
    plt.grid(True)
    return plt

# runplt() returns the pyplot module itself, so this rebinding keeps `plt`
# pointing at matplotlib.pyplot (now with a fresh, configured figure).
plt = runplt()
# Training data: pizza diameters and prices. Each sample is wrapped in a
# one-element list so X (and y) are 2-D, as scikit-learn's fit() expects.
X = [[6], [8], [10], [14], [18]]
y = [[7], [9], [13], [17.5], [18]]
plt.plot(X, y, 'k.')
plt.show()

# Train the model
model = LinearRegression()
model.fit(X, y)
# Because y is 2-D, predict() returns a (1, 1) array for one sample. Extract
# the single value with .item() instead of formatting a 1-element array with
# "%.2f", which relies on the ndim>0 array-to-scalar conversion deprecated
# in NumPy 1.25.
predicted_price = model.predict(np.array([12]).reshape(-1, 1)).item()
print("预测一张12英寸匹萨价格：￥%.2f" % predicted_price)